{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fbf51cdf-39fc-4267-b9ed-f9899f398ac5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install torch_geometric==2.3.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4cfb810b-2979-432a-9186-3a03f6ae086c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os.path as osp\n",
    "from typing import Any, Dict, Optional\n",
    "\n",
    "import torch\n",
    "from torch.nn import (\n",
    "    BatchNorm1d,\n",
    "    Embedding,\n",
    "    Linear,\n",
    "    ModuleList,\n",
    "    ReLU,\n",
    "    Sequential,\n",
    ")\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import ZINC\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import GINEConv, global_add_pool\n",
    "import inspect\n",
    "from typing import Any, Dict, Optional\n",
    "\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "from torch.nn import Dropout, Linear, Sequential\n",
    "\n",
    "from torch_geometric.nn.conv import MessagePassing\n",
    "from torch_geometric.nn.inits import reset\n",
    "from torch_geometric.nn.resolver import (\n",
    "    activation_resolver,\n",
    "    normalization_resolver,\n",
    ")\n",
    "from torch_geometric.typing import Adj\n",
    "from torch_geometric.utils import to_dense_batch\n",
    "\n",
    "from mamba_ssm import Mamba\n",
    "from torch_geometric.utils import degree, sort_edge_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "49609a76-1d75-438c-b9f0-e88439dd39d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def permute_within_batch(x, batch):\n",
    "    # Enumerate over unique batch indices\n",
    "    unique_batches = torch.unique(batch)\n",
    "    \n",
    "    # Initialize list to store permuted indices\n",
    "    permuted_indices = []\n",
    "\n",
    "    for batch_index in unique_batches:\n",
    "        # Extract indices for the current batch\n",
    "        indices_in_batch = (batch == batch_index).nonzero().squeeze()\n",
    "        \n",
    "        # Permute indices within the current batch\n",
    "        permuted_indices_in_batch = indices_in_batch[torch.randperm(len(indices_in_batch))]\n",
    "        \n",
    "        # Append permuted indices to the list\n",
    "        permuted_indices.append(permuted_indices_in_batch)\n",
    "    \n",
    "    # Concatenate permuted indices into a single tensor\n",
    "    permuted_indices = torch.cat(permuted_indices)\n",
    "\n",
    "    return permuted_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1131a620-e047-4191-a843-f7f285877e89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path, subset = '/scratch/ssd004/scratch/tsepaole/ZINC_full/', False\n",
    "path, subset = '', True\n",
    "\n",
    "transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')\n",
    "train_dataset = ZINC(path, subset=subset, split='train', pre_transform=transform)\n",
    "val_dataset = ZINC(path, subset=subset, split='val', pre_transform=transform)\n",
    "test_dataset = ZINC(path, subset=subset, split='test', pre_transform=transform)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=64)\n",
    "test_loader = DataLoader(test_dataset, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "34dc6712-8c4a-44f7-94e9-d582c5e16bcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPSConv(torch.nn.Module):\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        channels: int,\n",
    "        conv: Optional[MessagePassing],\n",
    "        heads: int = 1,\n",
    "        dropout: float = 0.0,\n",
    "        attn_dropout: float = 0.0,\n",
    "        act: str = 'relu',\n",
    "        att_type: str = 'transformer',\n",
    "        order_by_degree: bool = False,\n",
    "        shuffle_ind: int = 0,\n",
    "        d_state: int = 16,\n",
    "        d_conv: int = 4,\n",
    "        act_kwargs: Optional[Dict[str, Any]] = None,\n",
    "        norm: Optional[str] = 'batch_norm',\n",
    "        norm_kwargs: Optional[Dict[str, Any]] = None,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.channels = channels\n",
    "        self.conv = conv\n",
    "        self.heads = heads\n",
    "        self.dropout = dropout\n",
    "        self.att_type = att_type\n",
    "        self.shuffle_ind = shuffle_ind\n",
    "        self.order_by_degree = order_by_degree\n",
    "        \n",
    "        assert (self.order_by_degree==True and self.shuffle_ind==0) or (self.order_by_degree==False), f'order_by_degree={self.order_by_degree} and shuffle_ind={self.shuffle_ind}'\n",
    "        \n",
    "        if self.att_type == 'transformer':\n",
    "            self.attn = torch.nn.MultiheadAttention(\n",
    "                channels,\n",
    "                heads,\n",
    "                dropout=attn_dropout,\n",
    "                batch_first=True,\n",
    "            )\n",
    "        if self.att_type == 'mamba':\n",
    "            self.self_attn = Mamba(\n",
    "                d_model=channels,\n",
    "                d_state=d_state,\n",
    "                d_conv=d_conv,\n",
    "                expand=1\n",
    "            )\n",
    "            \n",
    "        self.mlp = Sequential(\n",
    "            Linear(channels, channels * 2),\n",
    "            activation_resolver(act, **(act_kwargs or {})),\n",
    "            Dropout(dropout),\n",
    "            Linear(channels * 2, channels),\n",
    "            Dropout(dropout),\n",
    "        )\n",
    "\n",
    "        norm_kwargs = norm_kwargs or {}\n",
    "        self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)\n",
    "        self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)\n",
    "        self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)\n",
    "\n",
    "        self.norm_with_batch = False\n",
    "        if self.norm1 is not None:\n",
    "            signature = inspect.signature(self.norm1.forward)\n",
    "            self.norm_with_batch = 'batch' in signature.parameters\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n",
    "        if self.conv is not None:\n",
    "            self.conv.reset_parameters()\n",
    "        self.attn._reset_parameters()\n",
    "        reset(self.mlp)\n",
    "        if self.norm1 is not None:\n",
    "            self.norm1.reset_parameters()\n",
    "        if self.norm2 is not None:\n",
    "            self.norm2.reset_parameters()\n",
    "        if self.norm3 is not None:\n",
    "            self.norm3.reset_parameters()\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        x: Tensor,\n",
    "        edge_index: Adj,\n",
    "        batch: Optional[torch.Tensor] = None,\n",
    "        **kwargs,\n",
    "    ) -> Tensor:\n",
    "        r\"\"\"Runs the forward pass of the module.\"\"\"\n",
    "        hs = []\n",
    "        if self.conv is not None:  # Local MPNN.\n",
    "            h = self.conv(x, edge_index, **kwargs)\n",
    "            h = F.dropout(h, p=self.dropout, training=self.training)\n",
    "            h = h + x\n",
    "            if self.norm1 is not None:\n",
    "                if self.norm_with_batch:\n",
    "                    h = self.norm1(h, batch=batch)\n",
    "                else:\n",
    "                    h = self.norm1(h)\n",
    "            hs.append(h)\n",
    "\n",
    "        ### Global attention transformer-style model.\n",
    "        if self.att_type == 'transformer':\n",
    "            h, mask = to_dense_batch(x, batch)\n",
    "            h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False)\n",
    "            h = h[mask]\n",
    "            \n",
    "        if self.att_type == 'mamba':\n",
    "            \n",
    "            if self.order_by_degree:\n",
    "                deg = degree(edge_index[0], x.shape[0]).to(torch.long)\n",
    "                order_tensor = torch.stack([batch, deg], 1).T\n",
    "                _, x = sort_edge_index(order_tensor, edge_attr=x)\n",
    "                \n",
    "            if self.shuffle_ind == 0:\n",
    "                h, mask = to_dense_batch(x, batch)\n",
    "                h = self.self_attn(h)[mask]\n",
    "            else:\n",
    "                mamba_arr = []\n",
    "                for _ in range(self.shuffle_ind):\n",
    "                    h_ind_perm = permute_within_batch(x, batch)\n",
    "                    h_i, mask = to_dense_batch(x[h_ind_perm], batch)\n",
    "                    h_i = self.self_attn(h_i)[mask][h_ind_perm]\n",
    "                    mamba_arr.append(h_i)\n",
    "                h = sum(mamba_arr) / self.shuffle_ind\n",
    "        ###\n",
    "        \n",
    "        h = F.dropout(h, p=self.dropout, training=self.training)\n",
    "        h = h + x  # Residual connection.\n",
    "        if self.norm2 is not None:\n",
    "            if self.norm_with_batch:\n",
    "                h = self.norm2(h, batch=batch)\n",
    "            else:\n",
    "                h = self.norm2(h)\n",
    "        hs.append(h)\n",
    "\n",
    "        out = sum(hs)  # Combine local and global outputs.\n",
    "\n",
    "        out = out + self.mlp(out)\n",
    "        if self.norm3 is not None:\n",
    "            if self.norm_with_batch:\n",
    "                out = self.norm3(out, batch=batch)\n",
    "            else:\n",
    "                out = self.norm3(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "    def __repr__(self) -> str:\n",
    "        return (f'{self.__class__.__name__}({self.channels}, '\n",
    "                f'conv={self.conv}, heads={self.heads})')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d9fea5c7-9fc9-4388-98fd-8c45787d0abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GraphModel(torch.nn.Module):\n",
    "    def __init__(self, channels: int, pe_dim: int, num_layers: int, model_type: str, shuffle_ind: int, d_state: int, d_conv: int, order_by_degree: False):\n",
    "        super().__init__()\n",
    "\n",
    "        self.node_emb = Embedding(28, channels - pe_dim)\n",
    "        self.pe_lin = Linear(20, pe_dim)\n",
    "        self.pe_norm = BatchNorm1d(20)\n",
    "        self.edge_emb = Embedding(4, channels)\n",
    "        self.model_type = model_type\n",
    "        self.shuffle_ind = shuffle_ind\n",
    "        self.order_by_degree = order_by_degree\n",
    "        \n",
    "        self.convs = ModuleList()\n",
    "        for _ in range(num_layers):\n",
    "            nn = Sequential(\n",
    "                Linear(channels, channels),\n",
    "                ReLU(),\n",
    "                Linear(channels, channels),\n",
    "            )\n",
    "            if self.model_type == 'gine':\n",
    "                conv = GINEConv(nn)\n",
    "                \n",
    "            if self.model_type == 'mamba':\n",
    "                conv = GPSConv(channels, GINEConv(nn), heads=4, attn_dropout=0.5,\n",
    "                               att_type='mamba',\n",
    "                               shuffle_ind=self.shuffle_ind,\n",
    "                               order_by_degree=self.order_by_degree,\n",
    "                               d_state=d_state, d_conv=d_conv)\n",
    "                \n",
    "            if self.model_type == 'transformer':\n",
    "                conv = GPSConv(channels, GINEConv(nn), heads=4, attn_dropout=0.5, att_type='transformer')\n",
    "                \n",
    "            # conv = GINEConv(nn)\n",
    "            self.convs.append(conv)\n",
    "\n",
    "        self.mlp = Sequential(\n",
    "            Linear(channels, channels // 2),\n",
    "            ReLU(),\n",
    "            Linear(channels // 2, channels // 4),\n",
    "            ReLU(),\n",
    "            Linear(channels // 4, 1),\n",
    "        )\n",
    "\n",
    "    def forward(self, x, pe, edge_index, edge_attr, batch):\n",
    "        x_pe = self.pe_norm(pe)\n",
    "        x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)\n",
    "        edge_attr = self.edge_emb(edge_attr)\n",
    "\n",
    "        for conv in self.convs:\n",
    "            if self.model_type == 'gine':\n",
    "                x = conv(x, edge_index, edge_attr=edge_attr)\n",
    "            else:\n",
    "                x = conv(x, edge_index, batch, edge_attr=edge_attr)\n",
    "                \n",
    "        x = global_add_pool(x, batch)\n",
    "        return self.mlp(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "85c6d04a-4eff-45ef-be39-93e4b520a967",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    \n",
    "    total_loss = 0\n",
    "    for data in train_loader:\n",
    "        data = data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        out = model(data.x, data.pe, data.edge_index, data.edge_attr,\n",
    "                    data.batch)\n",
    "        loss = (out.squeeze() - data.y).abs().mean()\n",
    "        loss.backward()\n",
    "        total_loss += loss.item() * data.num_graphs\n",
    "        optimizer.step()\n",
    "    return total_loss / len(train_loader.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0ddc967c-1c69-4d26-aa30-20c212c92c81",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def test(loader):\n",
    "    model.eval()\n",
    "\n",
    "    total_error = 0\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        # print(data.x.shape)\n",
    "        out = model(data.x, data.pe, data.edge_index, data.edge_attr,\n",
    "                    data.batch)\n",
    "        total_error += (out.squeeze() - data.y).abs().sum().item()\n",
    "    return total_error / len(loader.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "03d5de1c-0c2d-4fd4-81db-58f5f4257764",
   "metadata": {},
   "outputs": [],
   "source": [
    "# it.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cf9b7ab1-54e5-415d-be4a-c43742c33608",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01, Loss: 0.7204, Val: 0.7170, Test: 0.7365\n",
      "Epoch: 02, Loss: 0.6121, Val: 0.5867, Test: 0.6100\n",
      "Epoch: 03, Loss: 0.5747, Val: 0.7788, Test: 0.7760\n",
      "Epoch: 04, Loss: 0.5386, Val: 0.5892, Test: 0.5950\n",
      "Epoch: 05, Loss: 0.5261, Val: 0.5417, Test: 0.5383\n",
      "Epoch: 06, Loss: 0.4998, Val: nan, Test: 0.6481\n",
      "Epoch: 07, Loss: 0.4902, Val: 0.4364, Test: 0.4555\n",
      "Epoch: 08, Loss: 0.4680, Val: 0.5670, Test: 0.6019\n",
      "Epoch: 09, Loss: 0.4335, Val: 0.4619, Test: 0.4543\n",
      "Epoch: 10, Loss: 0.4447, Val: 0.3931, Test: 0.4147\n",
      "Epoch: 11, Loss: 0.4289, Val: 0.4258, Test: 0.4431\n",
      "Epoch: 12, Loss: 0.4220, Val: 0.4031, Test: 0.4058\n",
      "Epoch: 13, Loss: 0.4112, Val: 0.4105, Test: 0.4198\n",
      "Epoch: 14, Loss: 0.4066, Val: 0.4957, Test: 0.4886\n",
      "Epoch: 15, Loss: 0.3990, Val: 0.4142, Test: 0.3974\n",
      "Epoch: 16, Loss: 0.3858, Val: 0.5763, Test: nan\n",
      "Epoch: 17, Loss: 0.3757, Val: 0.3922, Test: nan\n",
      "Epoch: 18, Loss: 0.3749, Val: 0.3993, Test: 0.4512\n",
      "Epoch: 19, Loss: 0.3724, Val: nan, Test: nan\n",
      "Epoch: 20, Loss: 0.3494, Val: 0.3795, Test: nan\n",
      "Epoch: 21, Loss: 0.3509, Val: 0.3653, Test: 0.3522\n",
      "Epoch: 22, Loss: 0.3329, Val: 0.3485, Test: nan\n",
      "Epoch: 23, Loss: 0.3336, Val: 0.3418, Test: 255.8957\n",
      "Epoch: 24, Loss: 0.3304, Val: 0.3540, Test: 0.3364\n",
      "Epoch: 25, Loss: 0.3282, Val: 0.3345, Test: nan\n",
      "Epoch: 26, Loss: 0.3101, Val: 0.3274, Test: nan\n",
      "Epoch: 27, Loss: 0.3247, Val: 0.3062, Test: nan\n",
      "Epoch: 28, Loss: 0.3180, Val: nan, Test: nan\n",
      "Epoch: 29, Loss: 0.3144, Val: 0.3217, Test: 0.2960\n",
      "[0.736513614654541, 0.6099917201995849, 0.7760194625854492, 0.5949835357666016, 0.5382931175231933, 0.6480579833984375, 0.45554696464538574, 0.6018772258758545, 0.45431803703308105, 0.4146977252960205, 0.4431104040145874, 0.4057636137008667, 0.41979767227172854, 0.4885702114105225, 0.39735649681091306, nan, nan, 0.451167293548584, nan, nan, 0.3521624622344971, nan, 255.89569853305818, 0.33635227489471436, nan, nan, nan, nan, 0.296005051612854]\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "model = GraphModel(channels=64, pe_dim=8, num_layers=10,\n",
    "                   model_type='mamba',\n",
    "                   shuffle_ind=0, order_by_degree=True,\n",
    "                   d_conv=4, d_state=16,\n",
    "                  ).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)\n",
    "scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,\n",
    "                              min_lr=0.00001)\n",
    "arr = []\n",
    "for epoch in range(1, 30):\n",
    "    loss = train()\n",
    "    val_mae = test(val_loader)\n",
    "    test_mae = test(test_loader)\n",
    "    scheduler.step(val_mae)\n",
    "    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '\n",
    "          f'Test: {test_mae:.4f}')\n",
    "    arr.append(test_mae)\n",
    "ordering = arr\n",
    "print(ordering)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dadbc115-6b16-4266-af79-75ec97b304a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01, Loss: 0.6478, Val: 0.5409, Test: 0.5737\n",
      "Epoch: 02, Loss: 0.5205, Val: 0.4522, Test: 0.4622\n",
      "Epoch: 03, Loss: 0.4889, Val: 0.5605, Test: 0.5807\n",
      "Epoch: 04, Loss: 0.4440, Val: 0.3877, Test: 0.3950\n",
      "Epoch: 05, Loss: 0.4151, Val: 0.4781, Test: 0.4825\n",
      "Epoch: 06, Loss: 0.4200, Val: 0.3819, Test: 0.3898\n",
      "Epoch: 07, Loss: 0.3929, Val: 0.4256, Test: 0.4256\n",
      "Epoch: 08, Loss: 0.3695, Val: 0.3649, Test: 0.3617\n",
      "Epoch: 09, Loss: 0.3680, Val: 0.4223, Test: 0.3876\n",
      "Epoch: 10, Loss: 0.3569, Val: 0.3277, Test: 0.3216\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "model = GraphModel(channels=64, pe_dim=8, num_layers=10,\n",
    "                   model_type='mamba',\n",
    "                   shuffle_ind=1, order_by_degree=False,\n",
    "                   d_conv=4, d_state=16,\n",
    "                  ).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)\n",
    "scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,\n",
    "                              min_lr=0.00001)\n",
    "arr = []\n",
    "for epoch in range(1, 30):\n",
    "    loss = train()\n",
    "    val_mae = test(val_loader)\n",
    "    test_mae = test(test_loader)\n",
    "    scheduler.step(val_mae)\n",
    "    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '\n",
    "          f'Test: {test_mae:.4f}')\n",
    "    arr.append(test_mae)\n",
    "permute = arr\n",
    "print(permute)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15ab378c-1563-4177-8078-61b036aa9c6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import matplotlib.pyplot as plt\n",
    "# import pandas as pd\n",
    "\n",
    "# import numpy as np\n",
    "# res_df = pd.read_csv('30_ep_res.csv')\n",
    "\n",
    "# WINDOW = 1\n",
    "# fig, ax = plt.subplots(1, figsize=(15,5))\n",
    "\n",
    "# for col in res_df.columns:\n",
    "#     plt.plot(res_df[col].clip(0,0.7).rolling(WINDOW, min_periods=1).mean(), label=col)\n",
    "\n",
    "# plt.legend()\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15426d6e-d6a4-4b46-80c8-dae456caaa4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(1, figsize=(15,5))\n",
    "\n",
    "plt.plot(permute[:20], label='permute')\n",
    "plt.plot(ordering[:20], label='order')\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67d97587-f99a-46b9-9f7b-74d81d42fca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "fffffffffff"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0c5b0c6-64be-4a4b-9e3f-62a76355dce4",
   "metadata": {},
   "source": [
    "# Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c05860c-e4cc-4e49-ad4c-e746a22ea8e0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "41bf4f39-8190-4b24-b391-b31b161eca7b",
   "metadata": {},
   "source": [
    "# Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bc68822-3415-4685-9e62-371dc919bd93",
   "metadata": {},
   "outputs": [],
   "source": [
    "ggggggggggg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47643725-13bc-4e29-9bc5-0d748a10d951",
   "metadata": {},
   "outputs": [],
   "source": [
    "it = next(iter(train_loader))\n",
    "# h, mask = to_dense_batch(it.x, it.batch)\n",
    "# it.x.shape, h.shape, mask.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1955eeda-6fc6-4f4f-8c97-b0643a9c3b6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "deg.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d359e9b1-e3ba-4ff8-8ff6-3a17e7deb306",
   "metadata": {},
   "outputs": [],
   "source": [
    "it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f7447ec-f9ab-43de-8a39-f2707f4fb020",
   "metadata": {},
   "outputs": [],
   "source": [
    "it.edge_index[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e281930c-5708-4cf9-866f-c1d2efde207c",
   "metadata": {},
   "outputs": [],
   "source": [
    "it.to(device)\n",
    "out = model(it.x, it.pe, it.edge_index, it.edge_attr,\n",
    "                    it.batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85e47e6b-2721-4eeb-9c70-c8fc269cc4a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = torch.tensor([0,0,0,1,1,1,1])\n",
    "x = torch.tensor([0,1,2,3,4,5,6])\n",
    "batch.shape, x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef569b66-a435-4f2a-96ad-f33177ca9435",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def permute_within_batch(x, batch):\n",
    "    # Enumerate over unique batch indices\n",
    "    unique_batches = torch.unique(batch)\n",
    "    \n",
    "    # Initialize list to store permuted indices\n",
    "    permuted_indices = []\n",
    "\n",
    "    for batch_index in unique_batches:\n",
    "        # Extract indices for the current batch\n",
    "        indices_in_batch = (batch == batch_index).nonzero().squeeze()\n",
    "        \n",
    "        # Permute indices within the current batch\n",
    "        permuted_indices_in_batch = indices_in_batch[torch.randperm(len(indices_in_batch))]\n",
    "        \n",
    "        # Append permuted indices to the list\n",
    "        permuted_indices.append(permuted_indices_in_batch)\n",
    "\n",
    "    # Concatenate permuted indices into a single tensor\n",
    "    permuted_indices = torch.cat(permuted_indices)\n",
    "\n",
    "    return permuted_indices\n",
    "\n",
    "# Example usage\n",
    "batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])\n",
    "x = torch.tensor([0, 10, 20, 30, 40, 50, 60])\n",
    "\n",
    "# Get permuted indices\n",
    "permuted_indices = permute_within_batch(x, batch)\n",
    "\n",
    "# Use permuted indices to get the permuted tensor\n",
    "permuted_x = x[permuted_indices]\n",
    "\n",
    "print(\"Original x:\", x)\n",
    "print(\"Permuted x:\", permuted_x)\n",
    "print(\"Permuted indices:\", permuted_indices)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "669db4fc-0bf5-49a9-a701-6280c793c74c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eca35cc4-faaa-417d-abb0-8da383ecef2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "mask[0].sum(), (it.batch==0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e624feb-987b-4cc7-b7a0-388d2b6f121e",
   "metadata": {},
   "outputs": [],
   "source": [
    "self_attn = Mamba(d_model=64, # Model dimension d_model\n",
    "                                d_state=16,  # SSM state expansion factor\n",
    "                                d_conv=4,    # Local convolution width\n",
    "                                expand=1,    # Block expansion factor\n",
    "                            )\n",
    "print(sum(p.numel() for p in self_attn.parameters() if p.requires_grad), sum(p.numel() for p in self_attn.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61d5c45a-6d97-4659-99d5-14c63cec9818",
   "metadata": {},
   "outputs": [],
   "source": [
    "self_attn = Mamba(d_model=64, # Model dimension d_model\n",
    "                                d_state=8,  # SSM state expansion factor\n",
    "                                d_conv=2,    # Local convolution width\n",
    "                                expand=1,    # Block expansion factor\n",
    "                            )\n",
    "print(sum(p.numel() for p in self_attn.parameters() if p.requires_grad), sum(p.numel() for p in self_attn.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1359039-8fa5-4f01-a4c6-2171dc188300",
   "metadata": {},
   "outputs": [],
   "source": [
    "self_attn = Mamba(d_model=64, # Model dimension d_model\n",
    "                                d_state=16,  # SSM state expansion factor\n",
    "                                d_conv=8,    # Local convolution width\n",
    "                                expand=1,    # Block expansion factor\n",
    "                            )\n",
    "print(sum(p.numel() for p in self_attn.parameters() if p.requires_grad), sum(p.numel() for p in self_attn.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bd4f3f9-2eef-4731-8ae0-a00e7fdaa6b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "self_attn = torch.nn.MultiheadAttention(\n",
    "                64,\n",
    "                4,\n",
    "                dropout=0.5,\n",
    "                batch_first=True,\n",
    "            )\n",
    "print(sum(p.numel() for p in self_attn.parameters() if p.requires_grad), sum(p.numel() for p in self_attn.parameters()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mols",
   "language": "python",
   "name": "mols"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
